{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7342e20d",
   "metadata": {},
   "source": [
    "# Demo: Using Embeddings in Data Processing with Daft\n",
    "\n",
    "Daft makes working with complex data easy. This demonstration will show an **end-to-end example of data processing with embeddings in Daft**. We will:\n",
    "1. load a large dataset, \n",
    "2. compute embeddings, \n",
    "3. use the embeddings for semantic similarity, and \n",
    "4. inspect the results and then write them out.\n",
    "\n",
    "----\n",
    "\n",
    "**Scenario:** There are a lot of questions on StackExchange. Unfortunately, the vast majority of questions do not have high visibility or good answers. However, a similar question with a high-quality answer may already exist elsewhere on the site.\n",
    "\n",
    "We would like to go through all the questions on StackExchange and **associate each question with another question that is highly rated**.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e59162ca",
   "metadata": {},
   "source": [
    "## Step 0: Dependencies and configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7154cd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install Daft!\n",
    "!pip install 'daft[aws]'\n",
    "\n",
    "# We will use sentence-transformers for computing embeddings.\n",
    "!pip install sentence-transformers accelerate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f40dafa",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "CI = False"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a45a28c",
   "metadata": {},
   "source": [
    "## Step 1: Load the dataset\n",
    "\n",
    "We will use the **StackExchange crawl from the [RedPajama dataset](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T)**. It is 75GB of `jsonl` files. \n",
    "\n",
    "*EDIT (June 2023): Our hosted version of the full dataset is temporarily unavailable. Please enjoy the demo with the sample dataset for now.*\n",
    "\n",
    "**Note:** This demo runs best on a cluster with many GPUs available. Information on how to connect Daft to a cluster is available [here](https://www.getdaft.io/projects/docs/en/stable/learn/user_guides/poweruser/distributed-computing.html). \n",
    "\n",
    "If running on a single node, you can use the provided subsample of the data, which is 75MB in size. If you like, you can also truncate either dataset to a desired number of rows using `df.limit`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4e24aa5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-05-16 18:23:22.342 | WARNING  | daft.filesystem:_get_s3fs_kwargs:47 - AWS credentials not found - using anonymous access to S3 which will fail if the bucket you are accessing is not a public bucket. See boto3 documentation for more details on configuring your AWS credentials: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html\n",
      "2023-05-16 18:23:24.673 | WARNING  | daft.filesystem:_get_s3fs_kwargs:47 - AWS credentials not found - using anonymous access to S3 which will fail if the bucket you are accessing is not a public bucket. See boto3 documentation for more details on configuring your AWS credentials: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------+--------------------------------------------------------------------------------------------------------------+\n",
      "| text   | meta                                                                                                         |\n",
      "| Utf8   | Struct[language: Utf8, url: Utf8, timestamp: Timestamp(Seconds, None), source: Utf8, question_score: Utf8]   |\n",
      "+========+==============================================================================================================+\n",
      "+--------+--------------------------------------------------------------------------------------------------------------+\n",
      "(No data to display: Dataframe not materialized)\n"
     ]
    }
   ],
   "source": [
    "import daft\n",
    "\n",
    "SAMPLE_DATA_PATH = \"s3://daft-public-data/redpajama-1t-sample/stackexchange_sample.jsonl\"\n",
    "IO_CONFIG = daft.io.IOConfig(\n",
    "    s3=daft.io.S3Config(anonymous=True, region_name=\"us-west-2\")\n",
    ")  # Use anonymous-mode for accessing AWS S3\n",
    "\n",
    "df = daft.read_json(SAMPLE_DATA_PATH, io_config=IO_CONFIG)\n",
    "\n",
    "if CI:\n",
    "    df = df.limit(10)\n",
    "\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9a64086",
   "metadata": {},
   "source": [
    "## Step 2: Compute embeddings\n",
    "\n",
    "We can see there is a `text` column that holds the question and answer text and a `meta` column with metadata.\n",
    "\n",
    "Let's **compute the embeddings of our text**. We start by putting our model (SentenceTransformers) into a **Daft UDF**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "98ab9504",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "MODEL_NAME = \"all-MiniLM-L6-v2\"\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "\n",
    "@daft.udf(return_dtype=daft.DataType.python())\n",
    "class EncodingUDF:\n",
    "    def __init__(self):\n",
    "        from sentence_transformers import SentenceTransformer\n",
    "\n",
    "        self.model = SentenceTransformer(MODEL_NAME, device=device)\n",
    "\n",
    "    def __call__(self, text_col):\n",
    "        return [self.model.encode(text, convert_to_tensor=True) for text in text_col.to_pylist()]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "462b2f5f",
   "metadata": {},
   "source": [
    "Then, we can just call the UDF to run the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "a96b8ee9",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = df.with_column(\"embedding\", EncodingUDF(df[\"text\"]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43af0ae0",
   "metadata": {},
   "source": [
    "Pause and notice how easy it was to write this. \n",
    "\n",
    "In particular, we are not forced to do any sort of unwieldy type coercion on the embedding result; instead, we can return the result as-is, whatever it is, even when running Daft in a cluster. Daft dataframes can hold a wide range of data types and also grant you the flexibility of Python's dynamic typing when necessary.\n",
    "\n",
    "Next, let's also **extract the URL and score**:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48c37423",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = df.select(\n",
    "    df[\"embedding\"],\n",
    "    df[\"meta\"].struct.get(\"url\"),\n",
    "    df[\"meta\"].struct.get(\"question_score\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1aaf355e",
   "metadata": {},
   "source": [
    "and wait for all the results to finish computing:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "a1e873bc",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-05-16 18:23:29.522 | WARNING  | daft.filesystem:_get_s3fs_kwargs:47 - AWS credentials not found - using anonymous access to S3 which will fail if the bucket you are accessing is not a public bucket. See boto3 documentation for more details on configuring your AWS credentials: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Embeddings complete!\n"
     ]
    }
   ],
   "source": [
    "df = df.collect()\n",
    "embeddings_df = df\n",
    "print(\"Embeddings complete!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93e5e854",
   "metadata": {},
   "source": [
    "## Step 3: Semantic similarity\n",
    "\n",
    "Let's **get the top questions**. We will use `df.sort` to sort by score and then `df.limit` to keep only the highest-scoring rows: the square root of the total row count, rounded up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "6c8fd74a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "NUM_TOP_QUESTIONS = math.ceil(math.sqrt(len(df)))\n",
    "\n",
    "top_questions = (df.sort(df[\"question_score\"], desc=True).limit(NUM_TOP_QUESTIONS)).to_pydict()"
   ]
  },
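  {
   "cell_type": "markdown",
   "id": "d41c7a90",
   "metadata": {},
   "source": [
    "To get a feel for how this cutoff scales, here is a small standalone sketch (not part of the pipeline). The helper name `num_top_questions` is hypothetical; it simply repeats the `ceil(sqrt(N))` rule used for `NUM_TOP_QUESTIONS` above.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def num_top_questions(num_rows: int) -> int:\n",
    "    # Same rule as NUM_TOP_QUESTIONS above: the ceiling of the square root.\n",
    "    return math.ceil(math.sqrt(num_rows))\n",
    "\n",
    "\n",
    "for n in (10, 10_000, 1_000_000):\n",
    "    print(n, num_top_questions(n))\n",
    "```\n",
    "\n",
    "The pool of top questions grows much more slowly than the dataset, which keeps the similarity search in the next step cheap."
   ]
  },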
  {
   "cell_type": "markdown",
   "id": "bf8c2b78",
   "metadata": {},
   "source": [
    "Now we will **take each regular question** and **find a related top question**. For this we will need to do a similarity search. Let's do that within a Daft UDF."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "f3e3943c",
   "metadata": {},
   "outputs": [],
   "source": [
    "@daft.udf(\n",
    "    return_dtype=daft.DataType.struct(\n",
    "        {\n",
    "            \"related_top_question\": daft.DataType.string(),\n",
    "            \"similarity\": daft.DataType.float64(),\n",
    "        }\n",
    "    )\n",
    ")\n",
    "def similarity_search(embedding_col, top_embeddings, top_urls):\n",
    "    if len(embedding_col) == 0:\n",
    "        return []\n",
    "\n",
    "    from sentence_transformers import util\n",
    "\n",
    "    # Tensor prep\n",
    "    query_embedding_t = torch.stack(embedding_col.to_pylist())\n",
    "    if torch.cuda.is_available():\n",
    "        query_embedding_t = query_embedding_t.to(\"cuda\")\n",
    "        top_embeddings = top_embeddings.to(\"cuda\")\n",
    "\n",
    "    # Do semantic search\n",
    "    results = util.semantic_search(query_embedding_t, top_embeddings, top_k=1)\n",
    "\n",
    "    # Extract URL and score from search results\n",
    "    results = [res[0] for res in results]\n",
    "    results = [\n",
    "        {\n",
    "            \"related_top_question\": top_urls[res[\"corpus_id\"]],\n",
    "            \"similarity\": res[\"score\"],\n",
    "        }\n",
    "        for res in results\n",
    "    ]\n",
    "    return results\n",
    "\n",
    "\n",
    "import torch\n",
    "\n",
    "df = df.with_column(\n",
    "    \"search_result\",\n",
    "    similarity_search(\n",
    "        df[\"embedding\"],\n",
    "        top_embeddings=torch.stack(top_questions[\"embedding\"]),\n",
    "        top_urls=top_questions[\"url\"],\n",
    "    ),\n",
    ")\n",
    "\n",
    "df = df.select(\n",
    "    df[\"url\"],\n",
    "    df[\"question_score\"],\n",
    "    df[\"search_result\"][\"related_top_question\"].alias(\"related_top_question\"),\n",
    "    df[\"search_result\"][\"similarity\"].alias(\"similarity\"),\n",
    ")"
   ]
  },
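  {
   "cell_type": "markdown",
   "id": "e7b02f3c",
   "metadata": {},
   "source": [
    "For intuition, `util.semantic_search` with `top_k=1` pairs each query embedding with its most similar corpus embedding, scored by cosine similarity by default. The sketch below reproduces that idea in plain Python on toy 2-D vectors. The function names and the toy data are illustrative assumptions, not part of Daft or sentence-transformers, and the real library works on batched tensors rather than Python lists.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def cosine_similarity(a, b):\n",
    "    # Dot product divided by the product of the vector lengths.\n",
    "    # Assumes both vectors are non-zero.\n",
    "    dot = sum(x * y for x, y in zip(a, b))\n",
    "    norm_a = math.sqrt(sum(x * x for x in a))\n",
    "    norm_b = math.sqrt(sum(y * y for y in b))\n",
    "    return dot / (norm_a * norm_b)\n",
    "\n",
    "\n",
    "def top1_search(queries, corpus):\n",
    "    # For each query, return (index, score) of the closest corpus vector.\n",
    "    results = []\n",
    "    for query in queries:\n",
    "        scores = [cosine_similarity(query, doc) for doc in corpus]\n",
    "        best = max(range(len(corpus)), key=scores.__getitem__)\n",
    "        results.append((best, scores[best]))\n",
    "    return results\n",
    "\n",
    "\n",
    "toy_corpus = [[1.0, 0.0], [0.0, 1.0]]\n",
    "toy_queries = [[0.9, 0.1], [0.2, 0.8]]\n",
    "print(top1_search(toy_queries, toy_corpus))\n",
    "```\n",
    "\n",
    "In the UDF above, the returned index plays the role of `corpus_id`, which is then used to look up the matching URL in `top_urls`."
   ]
  },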
  {
   "cell_type": "markdown",
   "id": "3859d307",
   "metadata": {},
   "source": [
    "## Step 4: Inspect and write results\n",
    "\n",
    "Did the matching work well? Let's take a peek at our best results to see if they make sense."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "49812f5c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "    <table class=\"dataframe\">\n",
       "<thead>\n",
       "<tr><th>URL<br>Utf8                                   </th><th style=\"text-align: right;\">  question_score<br>Int64</th><th>related_top_question<br>Utf8                </th><th style=\"text-align: right;\">  similarity<br>Float64</th></tr>\n",
       "</thead>\n",
       "<tbody>\n",
       "<tr><td>https://stackoverflow.com/questions/72682333  </td><td style=\"text-align: right;\">                       -1</td><td>https://askubuntu.com/questions/401449      </td><td style=\"text-align: right;\">               0.821935</td></tr>\n",
       "<tr><td>https://dba.stackexchange.com/questions/185574</td><td style=\"text-align: right;\">                        1</td><td>https://askubuntu.com/questions/401449      </td><td style=\"text-align: right;\">               0.783146</td></tr>\n",
       "<tr><td>https://stackoverflow.com/questions/23113375  </td><td style=\"text-align: right;\">                        0</td><td>https://stackoverflow.com/questions/9907682 </td><td style=\"text-align: right;\">               0.779253</td></tr>\n",
       "<tr><td>https://stackoverflow.com/questions/68984510  </td><td style=\"text-align: right;\">                        0</td><td>https://askubuntu.com/questions/401449      </td><td style=\"text-align: right;\">               0.770626</td></tr>\n",
       "<tr><td>https://stackoverflow.com/questions/72643833  </td><td style=\"text-align: right;\">                        1</td><td>https://stackoverflow.com/questions/34030373</td><td style=\"text-align: right;\">               0.740242</td></tr>\n",
       "<tr><td>https://stackoverflow.com/questions/6092305   </td><td style=\"text-align: right;\">                        1</td><td>https://stackoverflow.com/questions/9907682 </td><td style=\"text-align: right;\">               0.727464</td></tr>\n",
       "<tr><td>https://stackoverflow.com/questions/14266910  </td><td style=\"text-align: right;\">                        1</td><td>https://stackoverflow.com/questions/4690758 </td><td style=\"text-align: right;\">               0.723069</td></tr>\n",
       "<tr><td>https://stackoverflow.com/questions/60891048  </td><td style=\"text-align: right;\">                        1</td><td>https://stackoverflow.com/questions/1388818 </td><td style=\"text-align: right;\">               0.711756</td></tr>\n",
       "</tbody>\n",
       "</table>\n",
       "    <small>(Showing first 8 rows)</small>\n",
       "</div>"
      ],
      "text/plain": [
       "+----------------------+------------------+------------------------+--------------+\n",
       "| URL                  |   question_score | related_top_question   |   similarity |\n",
       "| Utf8                 |            Int64 | Utf8                   |      Float64 |\n",
       "+======================+==================+========================+==============+\n",
       "| https://stackoverflo |               -1 | https://askubuntu.co   |     0.821935 |\n",
       "| w.com/questions/7268 |                  | m/questions/401449     |              |\n",
       "| 2333                 |                  |                        |              |\n",
       "+----------------------+------------------+------------------------+--------------+\n",
       "| https://dba.stackexc |                1 | https://askubuntu.co   |     0.783146 |\n",
       "| hange.com/questions/ |                  | m/questions/401449     |              |\n",
       "| 185574               |                  |                        |              |\n",
       "+----------------------+------------------+------------------------+--------------+\n",
       "| https://stackoverflo |                0 | https://stackoverflo   |     0.779253 |\n",
       "| w.com/questions/2311 |                  | w.com/questions/9907   |              |\n",
       "| 3375                 |                  | 682                    |              |\n",
       "+----------------------+------------------+------------------------+--------------+\n",
       "| https://stackoverflo |                0 | https://askubuntu.co   |     0.770626 |\n",
       "| w.com/questions/6898 |                  | m/questions/401449     |              |\n",
       "| 4510                 |                  |                        |              |\n",
       "+----------------------+------------------+------------------------+--------------+\n",
       "| https://stackoverflo |                1 | https://stackoverflo   |     0.740242 |\n",
       "| w.com/questions/7264 |                  | w.com/questions/3403   |              |\n",
       "| 3833                 |                  | 0373                   |              |\n",
       "+----------------------+------------------+------------------------+--------------+\n",
       "| https://stackoverflo |                1 | https://stackoverflo   |     0.727464 |\n",
       "| w.com/questions/6092 |                  | w.com/questions/9907   |              |\n",
       "| 305                  |                  | 682                    |              |\n",
       "+----------------------+------------------+------------------------+--------------+\n",
       "| https://stackoverflo |                1 | https://stackoverflo   |     0.723069 |\n",
       "| w.com/questions/1426 |                  | w.com/questions/4690   |              |\n",
       "| 6910                 |                  | 758                    |              |\n",
       "+----------------------+------------------+------------------------+--------------+\n",
       "| https://stackoverflo |                1 | https://stackoverflo   |     0.711756 |\n",
       "| w.com/questions/6089 |                  | w.com/questions/1388   |              |\n",
       "| 1048                 |                  | 818                    |              |\n",
       "+----------------------+------------------+------------------------+--------------+\n",
       "(Showing first 8 rows)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = df.where(df[\"similarity\"] < 0.99)  # To ignore duplicate questions.\n",
    "df = df.sort(df[\"similarity\"], desc=True)\n",
    "df.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de7112e7",
   "metadata": {},
   "source": [
    "On the left-hand side is an average question without much activity. The link on the right-hand side points to a similar question that already has some high-quality answers. Success!\n",
    "\n",
    "Finally, we will probably want to save the results for future use. Let's write them out to Parquet files locally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "5537c090",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'file_path': ['question_matches.pq/6ee09e59-dbea-495d-894f-0f182567035d-0.parquet']}"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.write_parquet(\"question_matches.pq\").to_pydict()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f714f9d",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "We have shown a simple example of a complex data processing workflow. It involved typical tabular operations like sort and filter, but we also had to interleave some interesting operations on complex data: we created, stored, and searched across embeddings.\n",
    "\n",
    "**Daft** is a data processing framework that allows you to express these things easily, while also scaling up to large clusters right out of the box.\n",
    "\n",
    "# You can get Daft at [getdaft.io](https://getdaft.io)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
